{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "vocational-assessment",
   "metadata": {},
   "source": [
    "# 手写数字识别\n",
    "\n",
    "\n",
    "### 问题：分类问题（10类）\n",
    "\n",
    "### 输入：灰度图像（28×28个像素）\n",
    "\n",
    "### 输出：分类0-9"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd78c8d8",
   "metadata": {},
   "source": [
    "# 0.超参数设置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "513206fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\"\"\"\n",
    "每次在训练集中提取64张图像进行批量化训练，目的是提高训练速度。\n",
    "就好比搬砖，一次搬一块砖头的效率肯定要比一次能搬64块要低得多\n",
    "\"\"\"\n",
    "BATCH_SIZE = 64\n",
    "#学习率，学习率一般为0.01，0.1等等较小的数，为了在梯度下降求解时避免错过最优解\n",
    "LR = 0.001\n",
    "\"\"\"\n",
    "EPOCH 假如现在我有1000张训练图像，因为每次训练是64张，\n",
    "每当我1000张图像训练完就是一个EPOCH，训练多少个EPOCH自己决定\n",
    "\"\"\"\n",
    "EPOCH = 1\n",
    "\"\"\"\n",
    "现在我要训练的训练集是系统自带的，需要先下载数据集，\n",
     "当DOWNLOAD_MNIST为True时表示需要下载数据集，一旦下载完成并保存，\n",
    "然后这个参数就可以改为False，表示不用再次下载\n",
    "\"\"\"\n",
    "DOWNLOAD_MNIST = True\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "statistical-great",
   "metadata": {},
   "source": [
    "# 1.导入数据\n",
    "\n",
    "### 原始数据（来自keras.datasets.mnist）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "monetary-excellence",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 手动导入数据\n",
    "from tensorflow.keras.datasets import mnist\n",
    "\n",
    "(train_images, train_labels), (test_images, test_labels) = mnist.load_data() #mnist.load_data('路径')为下载并保存数据集位置，默认位置在C:\\Users\\管理员\\.keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98fb8fad",
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "# pytorch导入\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import Dataset, TensorDataset, DataLoader\n",
    "\n",
    "#训练集\n",
    "# 读取\n",
    "train_data = torchvision.datasets.MNIST(\n",
    "    root='./mnist',\n",
    "    train = True,\n",
    "    transform=torchvision.transforms.ToTensor(),\n",
    "    download=DOWNLOAD_MNIST\n",
    ")\n",
    "\n",
    "# 划分\n",
    "train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2 )\n",
    "#每个batch_size的shape为[64, 1, 28, 28]\n",
    "print(\"样本\")\n",
    "print(train_data.train_data.shape)\n",
    "print(train_data.train_data[:3])\n",
    "print(\"标签\")\n",
    "print(train_data.train_labels.shape)\n",
    "print(train_data.train_labels[:3])\n",
    "\n",
    "\n",
    "\n",
    "# 测试集\n",
    "# 读取\n",
    "test_data = torchvision.datasets.MNIST(\n",
    "    root='./mnist',\n",
    "    train = False,\n",
    ")\n",
    "\n",
    "# 处理\n",
    "\n",
    "test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000]/255.0\n",
    "\"\"\"\n",
     "test_data.test_data中的shape为[10000, 28, 28]代表1w张图像，都是28x28，但并未标明channels，因此用unsqueeze在dim=1方向上加一个维度，\n",
    "则shape变为[10000, 1, 28, 28]，然后转化为tensor的float32类型，取1w张中的2000张，并且将其图片进行归一化处理，避免图像几何变换的影响\n",
    "\"\"\"\n",
    "#标签取前2000\n",
    "test_y = test_data.test_labels[:2000]\n",
    "\n",
    "print(\"样本\")\n",
    "print(test_x.shape)\n",
    "print(test_x[:3])\n",
    "print(\"标签\")\n",
    "print(test_y.shape)\n",
    "print(test_y[:3])\n",
    "\n",
    "\n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "reverse-craps",
   "metadata": {},
   "source": [
    "# 2.创建自己的Datasets数据集\n",
    "\n",
     "from torch.utils.data import Dataset, TensorDataset, DataLoader\n",
    "\n",
    "\n",
    ">1.将数据转为tensor格式\n",
     ">>`数据 = torch.tensor(numpy数据)` \n",
    "\n",
    "\n",
    ">2.数据处理\n",
    ">>图像数据处理：\n",
    ">>>**图像数据列表维度shape：[图像数量,通道维数,图像长像素,图像宽像素]**  \n",
    ">>>缺少通道维黑白图像处理:`图片样本data = Variable(torch.unsqueeze(图片样本data, dim=1), volatile=True).type(torch.FloatTensor)/255`  \n",
    ">>>数据类型转换:`数据变量 = 数据变量.type(torch.FloatTensor)`\n",
    "\n",
    ">>标签处理\n",
    ">>>转换为one-hot编码:`标签labels = utils.to_categorical(标签labels)`  \n",
    ">>>标签转换成long数据格式：`标签labels = 标签labels.long()`\n",
    "\n",
    ">3.创建数据集\n",
    ">>`数据集 = TensorDataset(样本data, 标签labels)`\n",
    "\n",
    ">4.加载数据集\n",
    ">>`train_loader = DataLoader(train_dataset, batch_size=120)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d962467",
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "# 对分类标签y进行one-hot编码  utils.to_categorical(标签列表, num_classes=标签类别数, dtype='编码后标签格式')\n",
    "from tensorflow.keras import utils\n",
    "\n",
    "print(\"编码前\")\n",
    "print(train_labels)\n",
    "\n",
    "train_labels = utils.to_categorical(train_labels)\n",
    "test_labels = utils.to_categorical(test_labels)\n",
    "\n",
    "print(\"编码后\")\n",
    "print(train_labels)\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2900642b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  ...\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]]\n",
      "\n",
      " [[0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  ...\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]]\n",
      "\n",
      " [[0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  ...\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]]\n",
      "\n",
      " ...\n",
      "\n",
      " [[0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  ...\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]]\n",
      "\n",
      " [[0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  ...\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]]\n",
      "\n",
      " [[0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  ...\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]\n",
      "  [0 0 0 ... 0 0 0]]]\n",
      "torch.Size([60000, 28, 28])\n",
      "torch.Size([60000, 1, 28, 28])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-3-8705a23add60>:23: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.\n",
      "  train_images = Variable(torch.unsqueeze(train_images, dim=1), volatile=True).type(torch.FloatTensor)/255\n",
      "<ipython-input-3-8705a23add60>:24: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.\n",
      "  test_images = Variable(torch.unsqueeze(test_images, dim=1), volatile=True).type(torch.FloatTensor)/255\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import Dataset, TensorDataset, DataLoader\n",
    "\n",
    "print(train_images)\n",
    "\n",
    "# 1.把数据转换成tensor格式\n",
    "train_images = torch.tensor(train_images)\n",
    "train_labels = torch.tensor(train_labels)\n",
    "\n",
    "test_images = torch.tensor(test_images)\n",
    "test_labels = torch.tensor(test_labels)\n",
    "\n",
    "print(train_images.shape)\n",
    "\n",
    "\n",
    "# 2. 数据处理\n",
    "# 将标签转换成long格式\n",
    "train_labels = train_labels.long()\n",
    "test_labels = test_labels.long()\n",
    "\n",
    "# 图像数据调整增加维度 [图片数, 长, 宽]->[图片数, 通道数, 长, 宽], 将数据转为tensor的Float格式\n",
    "train_images = Variable(torch.unsqueeze(train_images, dim=1), volatile=True).type(torch.FloatTensor)/255\n",
    "test_images = Variable(torch.unsqueeze(test_images, dim=1), volatile=True).type(torch.FloatTensor)/255\n",
    "\n",
    "print(train_images.shape)\n",
    "\n",
    "\n",
    "# 3.创建数据集\n",
    "train_dataset = TensorDataset(train_images, train_labels)\n",
    "test_dataset = TensorDataset(test_images, test_labels)\n",
    "\n",
    "\n",
    "# 4.加载数据集\n",
    "train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)\n",
    "test_loader =DataLoader(test_dataset, batch_size=120)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "analyzed-duration",
   "metadata": {},
   "source": [
    "# 3.构建网络\n",
    "\n",
    "```python\n",
    "# 导入包\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# 定义神经网络类\n",
    "class 自定义神经网络类名(nn.Module):\n",
    "    # 可学习参数的层（如全连接层、卷积层等）\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.自定义layer名1 = nn.layer层(参数)\n",
    "        self.自定义layer名2 = nn.layer层(参数)\n",
    "        self.自定义layer名3 = nn.Sequential(\n",
     "            nn.layer层(参数),\n",
     "            nn.layer层(参数)\n",
     "        )\n",
    "        \n",
    "    # 实现模型的功能，实现各个层之间的连接关系\n",
     "    # nn.functional实现不具有可学习参数的层(如ReLU、dropout、BatchNormalization层)的构造\n",
    "    def forward(self, x):\n",
    "        x = self.自定义layer名1(x)\n",
    "        x = F.不可学习参数层(x)\n",
    "        x = self.自定义layer名2(x)\n",
     "        x = 不可学习参数层(self.自定义layer名3(x))\n",
    "        return x\n",
    "\n",
    "\n",
    "# 实例化神经网络\n",
    "model实例 = 自定义神经网络类名()\n",
    "\n",
    "\n",
    "# 保存模型\n",
    "torch.save(model实例, '存储路径')\n",
    "\n",
    "# 加载模型\n",
    "model = torch.load(\"model.pth\")\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "agreed-marble",
   "metadata": {},
   "source": [
    "#### ·可学习layer\n",
    "\n",
    ">**卷积层**\n",
    ">>nn.Conv2d(\n",
    ">>&nbsp;&nbsp;&nbsp;  in_channels = 输入特征矩阵的深度,   \n",
    ">>&nbsp;&nbsp;&nbsp;  out_channels = 卷积核数|输出特征矩阵深度,   \n",
    ">>&nbsp;&nbsp;&nbsp;  kernel_size = (卷积核长, 卷积核宽),   \n",
    ">>&nbsp;&nbsp;&nbsp;  stride = 卷积框步长,   \n",
    ">>&nbsp;&nbsp;&nbsp;  padding = (填充上下行数, 填充左右列数),   \n",
    ">>&nbsp;&nbsp;&nbsp;  dilation = 卷积核元素之间的间距1,   \n",
    ">>&nbsp;&nbsp;&nbsp;  groups = 从输入通道到输出通道的阻塞连接数1,   \n",
    ">>&nbsp;&nbsp;&nbsp;  bias = 添加偏置T/F,   \n",
    ">>&nbsp;&nbsp;&nbsp;  padding_mode = '填充数字zeros'  \n",
    ">>)\n",
    "\n",
    ">>说明：\n",
    ">>>in_channels = 输入通道维的元素数  \n",
    ">>>图像(通道，图像长，图像宽)（黑白图像通道=1，RGB图像通道=3）   \n",
    ">>>out_channels = 提取特征数 = 输出特征矩阵深度 = 输出特征矩阵通道维\n",
    "\n",
    ">**全连接层**\n",
    ">>nn.Linear(in_features=每个输入样本的大小, out_features=每个输出样本的大小)  \n",
    ">>说明：\n",
    "\n",
    "\n",
    "#### ·不可学习layer\n",
    ">**将多维数据转成一维**\n",
    ">>x.view(x.size(0), -1)  \n",
    ">>说明：在卷积层转全连接层之间使用，用在forward(self, x)中"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "binding-pledge",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn, optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "# 定义神经网络类\n",
    "class CNN(nn.Module):\n",
    "    # 可学习参数的层（如全连接层、卷积层等）\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        \n",
    "        # 第一部分卷积层1\n",
    "        self.conv1 = nn.Sequential(\n",
    "            # 卷积层(输入通道维1，输出通道维16，卷积窗口3*3)\n",
    "            nn.Conv2d(1, 16, kernel_size=(3,3), stride=1, padding=1),  # 维度变换(1,28,28) （黑白图像1通道，长28像素，宽28像素）->(16,28,28) （16个卷积核提取16个特征通道，长28，宽28）图像边缘扩展，没被卷积抛去\n",
    "            # 激活函数\n",
    "            nn.ReLU(),\n",
    "            # 池化层\n",
    "            nn.MaxPool2d(2) # 维度变化(16,28,28)->(16,14,14)\n",
    "        )\n",
    "            \n",
    "        #第二部分卷积层2\n",
    "        self.conv2 = nn.Sequential(\n",
    "            # 卷积层(输入通道维16，输出通道维32，卷积窗口3*3)\n",
    "            nn.Conv2d(16, 32, kernel_size=(3,3), stride=1, padding=1),  # 维度变换(16,14,14)->(32,14,14)\n",
    "            # 激活函数\n",
    "            nn.ReLU(),\n",
    "            # 池化层\n",
    "            nn.MaxPool2d(2) # 维度变化(32,14,14)->(32,7,7)\n",
    "        )\n",
    "            \n",
    "        # 全连接层\n",
    "        self.out = nn.Linear(32*7*7, 10)\n",
    "            \n",
    "            \n",
    "        \n",
    "    # 实现模型的功能，实现各个层之间的连接关系\n",
     "    # nn.functional实现不具有可学习参数的层(如ReLU、dropout、BatchNormalization层)的构造\n",
    "    def forward(self, x):\n",
    "        # 执行卷积层1 conv1\n",
    "        x = self.conv1(x)\n",
    "        # 执行卷积层2 conv2\n",
    "        x = self.conv2(x)\n",
    "        # 将图像数据转为1维\n",
    "        x = x.view(x.size(0),-1)\n",
    "        # 执行全连接层 out\n",
    "        x = self.out(x)\n",
    "        return x\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3dfa53c7",
   "metadata": {},
   "source": [
    "# 4.训练模型\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2039f4b1",
   "metadata": {},
   "source": [
    "#### 方法一 ：利用torchkeras包训练\n",
    "\n",
    "(需要torchkeras支持)  \n",
     "from torchkeras import summary, Model \n",
    "\n",
    "1.实例化模型  \n",
    "```\n",
    "model = Model(自定义神经网络类名())\n",
    "```\n",
    "\n",
    "2.编译模型\n",
    "```\n",
    "model.compile(loss_func = 损失函数,\n",
    "             optimizer= 优化方法,\n",
    "             metrics_dict={\"accuracy\":accuracy})\n",
    "```\n",
    "\n",
    "3.训练模型\n",
    "```\n",
    "dfhistory = model.fit(训练次数,train_loader, test_loader, log_step_freq=100) \n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b593b836",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model(\n",
      "  (net): CNN(\n",
      "    (conv1): Sequential(\n",
      "      (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (1): ReLU()\n",
      "      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    )\n",
      "    (conv2): Sequential(\n",
      "      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "      (1): ReLU()\n",
      "      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    )\n",
      "    (out): Linear(in_features=1568, out_features=10, bias=True)\n",
      "  )\n",
      ")\n",
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1           [-1, 16, 28, 28]             160\n",
      "              ReLU-2           [-1, 16, 28, 28]               0\n",
      "         MaxPool2d-3           [-1, 16, 14, 14]               0\n",
      "            Conv2d-4           [-1, 32, 14, 14]           4,640\n",
      "              ReLU-5           [-1, 32, 14, 14]               0\n",
      "         MaxPool2d-6             [-1, 32, 7, 7]               0\n",
      "            Linear-7                   [-1, 10]          15,690\n",
      "================================================================\n",
      "Total params: 20,490\n",
      "Trainable params: 20,490\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.002991\n",
      "Forward/backward pass size (MB): 0.323074\n",
      "Params size (MB): 0.078163\n",
      "Estimated Total Size (MB): 0.404228\n",
      "----------------------------------------------------------------\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "from torchkeras import summary,Model\n",
    "\n",
    "# 实例化模型\n",
    "model = Model(CNN())\n",
    "model = model.float()\n",
    "\n",
    "# 查看模型\n",
    "print(model)\n",
    "print(summary(model, input_shape=(1,28,28)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "69746c8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "def accuracy(y_pred,y_true):\n",
    "    y_pred_cls = torch.argmax(nn.Softmax(dim=1)(y_pred),dim=1).data\n",
    "    return accuracy_score(y_true.numpy(),y_pred_cls.numpy())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "71543cd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "model.compile(loss_func = nn.CrossEntropyLoss(),\n",
    "             optimizer= torch.optim.Adam(model.parameters(),lr = 0.02),\n",
    "             metrics_dict={\"accuracy\":accuracy})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "57bbfa11",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start Training ...\n",
      "\n",
      "================================================================================2021-11-24 10:51:22\n",
      "{'step': 100, 'loss': 0.826, 'accuracy': 0.74}\n",
      "{'step': 200, 'loss': 0.537, 'accuracy': 0.833}\n",
      "{'step': 300, 'loss': 0.428, 'accuracy': 0.867}\n",
      "{'step': 400, 'loss': 0.372, 'accuracy': 0.884}\n",
      "{'step': 500, 'loss': 0.334, 'accuracy': 0.896}\n",
      "{'step': 600, 'loss': 0.305, 'accuracy': 0.905}\n",
      "{'step': 700, 'loss': 0.284, 'accuracy': 0.911}\n",
      "{'step': 800, 'loss': 0.272, 'accuracy': 0.915}\n",
      "{'step': 900, 'loss': 0.261, 'accuracy': 0.918}\n",
      "\n",
      " +-------+-------+----------+----------+--------------+\n",
      "| epoch |  loss | accuracy | val_loss | val_accuracy |\n",
      "+-------+-------+----------+----------+--------------+\n",
      "|   1   | 0.258 |  0.919   |  0.142   |    0.955     |\n",
      "+-------+-------+----------+----------+--------------+\n",
      "\n",
      "================================================================================2021-11-24 10:51:39\n",
      "{'step': 100, 'loss': 0.157, 'accuracy': 0.952}\n",
      "{'step': 200, 'loss': 0.162, 'accuracy': 0.951}\n",
      "{'step': 300, 'loss': 0.164, 'accuracy': 0.95}\n",
      "{'step': 400, 'loss': 0.165, 'accuracy': 0.95}\n",
      "{'step': 500, 'loss': 0.163, 'accuracy': 0.951}\n",
      "{'step': 600, 'loss': 0.162, 'accuracy': 0.951}\n",
      "{'step': 700, 'loss': 0.164, 'accuracy': 0.95}\n",
      "{'step': 800, 'loss': 0.163, 'accuracy': 0.951}\n",
      "{'step': 900, 'loss': 0.161, 'accuracy': 0.951}\n",
      "\n",
      " +-------+-------+----------+----------+--------------+\n",
      "| epoch |  loss | accuracy | val_loss | val_accuracy |\n",
      "+-------+-------+----------+----------+--------------+\n",
      "|   2   | 0.161 |  0.951   |  0.156   |    0.948     |\n",
      "+-------+-------+----------+----------+--------------+\n",
      "\n",
      "================================================================================2021-11-24 10:51:55\n",
      "{'step': 100, 'loss': 0.141, 'accuracy': 0.957}\n",
      "{'step': 200, 'loss': 0.137, 'accuracy': 0.958}\n",
      "{'step': 300, 'loss': 0.14, 'accuracy': 0.958}\n",
      "{'step': 400, 'loss': 0.141, 'accuracy': 0.957}\n",
      "{'step': 500, 'loss': 0.142, 'accuracy': 0.956}\n",
      "{'step': 600, 'loss': 0.142, 'accuracy': 0.956}\n",
      "{'step': 700, 'loss': 0.143, 'accuracy': 0.956}\n",
      "{'step': 800, 'loss': 0.145, 'accuracy': 0.955}\n",
      "{'step': 900, 'loss': 0.145, 'accuracy': 0.955}\n",
      "\n",
      " +-------+-------+----------+----------+--------------+\n",
      "| epoch |  loss | accuracy | val_loss | val_accuracy |\n",
      "+-------+-------+----------+----------+--------------+\n",
      "|   3   | 0.145 |  0.955   |   0.15   |    0.954     |\n",
      "+-------+-------+----------+----------+--------------+\n",
      "\n",
      "================================================================================2021-11-24 10:52:12\n",
      "Finished Training...\n"
     ]
    }
   ],
   "source": [
    "\n",
    "dfhistory = model.fit(3,train_loader, test_loader, log_step_freq=100) \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e561eb0e",
   "metadata": {},
   "outputs": [],
   "source": [
     "# 保存模型 torch.save(模型, '保存路径')\n",
    "\n",
    "torch.save(model,'pytorch_number_model.h5')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e460e1ba",
   "metadata": {},
   "source": [
    "#### 方法二：手动训练\n",
    "\n",
    "1.实例化模型\n",
    "```\n",
    "model变量 = 自定义神经网络类名()\n",
    "```\n",
    "\n",
    "2.定义损失函数和优化器\n",
    "```\n",
    "loss_func = nn.CrossEntropyLoss()\n",
    "\n",
    "optimizer = torch.optim.Adam(model变量.parameters(), lr = 学习率) \n",
    "```\n",
    "\n",
    "3.训练模型\n",
    "```\n",
    "for epoch in range(训练轮数):  # 训练轮数\n",
    "    \n",
    "    # 遍历训练集每条数据，进行训练 \n",
    "    running_loss = 0.0 # 记录一轮中每条训练数据预测的损失累加\n",
    "    for step, (x, y) in enumerate(训练集加载train_loader):   #【enumerate()枚举对象 得到格式（id，元素）】\n",
    "        b_x = Variable(x) # 数据x\n",
    "        b_y = Variable(y) # 标签y\n",
    "    \n",
     "        output = model变量(b_x) # 把数据输入进网络\n",
    "        loss = loss_func(output, b_y) # 计算一条数据的损失\n",
    "        running_loss += loss.item() # 损失累加\n",
    "        \n",
    "        optimizer.zero_grad() # 梯度置零\n",
    "        loss.backward()  # loss反向传播\n",
    "        optimizer.step() # 反向传播后参数更新\n",
    "        \n",
    "    \n",
    "    print('训练轮数：', epoch, ' 平均损失：',running_loss/len(train_loader)) #平均损失=每轮每条训练数据损失求和/每轮训练数据数\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "83c03e60",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN(\n",
      "  (conv1): Sequential(\n",
      "    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (1): ReLU()\n",
      "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (conv2): Sequential(\n",
      "    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (1): ReLU()\n",
      "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (out): Linear(in_features=1568, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# 初始化模型\n",
    "model_2 = CNN()\n",
    "\n",
    "# 查看模型\n",
    "print(model_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d6d52593",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 损失函数\n",
    "loss_func = nn.CrossEntropyLoss()\n",
    "\n",
    "# 优化器\n",
    "optimizer = torch.optim.Adam(model_2.parameters(), lr = 0.02) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "1a3c953b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练轮数： 0  训练平均损失loss： 0.15682169912543806\n",
      "训练轮数： 0  测试平均损失val_loss： 0.07712881063551842\n",
      "训练轮数： 1  训练平均损失loss： 0.08413001222844754\n",
      "训练轮数： 1  测试平均损失val_loss： 0.06314141603382138\n",
      "训练轮数： 2  训练平均损失loss： 0.07663096440670543\n",
      "训练轮数： 2  测试平均损失val_loss： 0.07487657513619135\n",
      "end\n"
     ]
    }
   ],
   "source": [
    "from torch.autograd import Variable\n",
    "\n",
    "EPOCH = 3 # 训练轮数\n",
    "\n",
    "#记录用于绘图\n",
    "losses = []#记录每次迭代后训练的loss\n",
    "eval_losses = []#测试的\n",
    "\n",
    "\n",
    "\n",
    "# 自定义训练方法\n",
    "\n",
    "for epoch in range(EPOCH):  # 训练轮数\n",
    "    # 遍历训练集每条数据，进行训练，得到每轮损失loss\n",
    "    running_loss = 0.0\n",
    "    for step, (x, y) in enumerate(train_loader):   #【enumerate()枚举对象 得到格式（id，元素）】\n",
    "        b_x = Variable(x) # 数据x\n",
    "        b_y = Variable(y) # 标签y\n",
    "    \n",
    "        output = model_2(b_x) # 把数据输入进网络\n",
    "        loss = loss_func(output, b_y) # 计算一条数据的损失\n",
    "        running_loss += loss.item() # 损失累加\n",
    "        \n",
    "        optimizer.zero_grad() # 梯度置零\n",
    "        loss.backward()  # loss反向传播\n",
    "        optimizer.step() # 反向传播后参数更新\n",
    "    \n",
    "    losses.append(running_loss/len(train_loader)) # 记录该轮平均损失，后续用于画图\n",
    "    print('训练轮数：', epoch, ' 训练平均损失loss：',running_loss/len(train_loader)) #平均损失=每轮每条训练数据损失求和/每轮训练数据数\n",
    "    \n",
    "    \n",
    "    \n",
    "    # 遍历测试集每条数据，进行测试，每轮训练后损失val_loss\n",
    "    running_loss = 0.0\n",
    "    for step, (x, y) in enumerate(test_loader): \n",
    "        b_x = Variable(x) # 数据x\n",
    "        b_y = Variable(y) # 标签y\n",
    "        \n",
    "        output = model_2(b_x) # 把数据输入进网络\n",
    "        loss = loss_func(output, b_y) # 计算一条数据的损失\n",
    "        running_loss += loss.item() # 损失累加\n",
    "    \n",
    "    eval_losses.append(running_loss/len(test_loader)) # 记录该轮平均损失，后续用于画图\n",
    "    print('训练轮数：', epoch, ' 测试平均损失val_loss：',running_loss/len(test_loader)) #平均损失=每轮每条训练数据损失求和/每轮训练数据数\n",
    "    \n",
    "        \n",
    "    \n",
    "print('end') \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9dc3433e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAEICAYAAABRSj9aAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAwkElEQVR4nO3deXxU9bnH8c8zkxWyAdkg7LLJliCBRG0p1apYbV1qC6goe21rtd6WXtvazdpq1dtarVcvmyubCy6tWqy7VRMJEEQMKCBLICEhLAmQhCzP/eNMyGQjE0gyyeR5v17zIjnLzJPx+D2/+Z3f+Y2oKsYYYwKXy98FGGOMaVsW9MYYE+As6I0xJsBZ0BtjTICzoDfGmABnQW+MMQHOgt50eSKyU0S+4e86jGkrFvTGGBPgLOiNMSbAWdAb4yEioSLygIjs8zweEJFQz7pYEfmniBwWkYMi8r6IuDzr/ltE9opIiYhsFZEL/fuXGFNXkL8LMKYD+RWQDqQACrwE3AH8GvgpkAvEebZNB1REhgM3AxNUdZ+IDATc7Vu2MadmLXpjal0H3KmqBapaCPwemOFZVwH0BgaoaoWqvq/ORFFVQCgwUkSCVXWnqm73S/XGNMGC3phafYBdXr/v8iwDuA/YBrwuIjtE5HYAVd0G/AT4HVAgIitFpA/GdCAW9MbU2gcM8Pq9v2cZqlqiqj9V1cHAt4D/qumLV9XlqvoVz74K/Ll9yzbm1Czojam1ArhDROJEJBb4DfA0gIhcLiJDRESAYpwumyoRGS4iF3gu2pYBpZ51xnQYFvTG1LoLyAI+ATYB6z3LAIYCbwBHgY+A/1XVd3D65+8BDgD5QDzwy3at2phmiH3xiDHGBDZr0RtjTICzoDfGmABnQW+MMQHOgt4YYwKcT1MgiMgU4G84t3YvVtV76q0fATwGnAP8SlXv91oXAywGRuOMMZ6tqh+d6vViY2N14MCBvv8VxhjTxa1bt+6AqsY1tq7ZoBcRN/AwcBHOXB9rReRlVf3Ma7ODwC3AlY08xd+Af6nqNSISAnRr7jUHDhxIVlZWc5sZY4zxEJFdTa3zpetmIrBNVXeo6glgJXCF9waeuUHW4swH4v3CUcAkYIlnuxOqerhl5RtjjDkTvgR9ErDH6/dczzJfDAYKgcdEZIOILBaR7i2s0RhjzBnwJeilkWW+3mUVhNNv/4iqjgOOAbc3+iIi80UkS0SyCgsLfXx6Y4wxzfHlYmwu0M/r9754Jnrycd9cVc30/P4cTQS9qi4EFgKkpqba7brGdCIVFRXk5uZSVlbm71ICXlhYGH379iU4ONjnfXwJ+rXAUBEZBOwFpgHX+vLkqpovIntEZLiqbgUuBD5rbj9jTOeSm5tLZGQkAwcOxJn3zbQFVaWoqIjc3FwGDRrk837NBr2qVorIzcAanOGVS1V1s4jc5Fn/qIgk4kwGFQVUi8hPgJGqWgz8GFjmGXGzA5jVwr/NJy9u2Mt9a7ay73ApfWLCWXDJcK4c5+ulBGPMmSgrK7OQbwciQq9evWhp97ZP4+hV9VXg1XrLHvX6OR+nS6exfbOB1BZV1UIvbtjLL1ZvorTCmR127+FSfrF6E4CFvTHtxEK+fZzO+xwQd8bet2bryZCvUVpRxX1rtvqpImOM6TgCIuj3HS5t0XJjTOCJiIjwdwkdlk9dNx1dn5hw9jYS6n1iwv1QjTGmOXZNrX0FRIt+wSXDCQ92N1h+fXp/P1RjjDmVmmtqew+XotReU3txw95WeX5VZcGCBYwePZoxY8awatUqAPLy8pg0aRIpKSmMHj2a999/n6qqKmbOnHly27/+9a8AbN++nSlTpjB+/Hi++tWvsmXLFgCeffZZRo8eTXJyMpMmTWqVettDQLToa1oCNS2E+KhQSk9Usvj9L7nw7ASGJUT6uUJjuo7f/2Mzn+0rbnL9ht2HOVFVXWdZaUUVP3/uE1Z8vLvRfUb2ieK33xrl0+uv
Xr2a7OxsNm7cyIEDB5gwYQKTJk1i+fLlXHLJJfzqV7+iqqqK48ePk52dzd69e/n0008BOHz4MADz58/n0UcfZejQoWRmZvLDH/6Qt956izvvvJM1a9aQlJR0ctvOICCCHpyw9/7o9+WBY0z9v4+4dlEGK+alM9TC3pgOoX7IN7e8pf7zn/8wffp03G43CQkJfO1rX2Pt2rVMmDCB2bNnU1FRwZVXXklKSgqDBw9mx44d/PjHP+ayyy7j4osv5ujRo3z44Yd897vfPfmc5eXlAJx//vnMnDmT733ve1x99dWtUm97CJigr29QbHdWzE9n+sIMpi/KZOX8NIbEW9gb09aaa3mff89bjV5TS4oJZ9X3zz3j12/qe7AnTZrEe++9xyuvvMKMGTNYsGABN9xwAxs3bmTNmjU8/PDDPPPMMzzwwAPExMSQnZ3d4DkeffRRMjMzeeWVV0hJSSE7O5tevXqdcc1tLSD66JtyVlwEy+elIwLTFmayreCov0sypstr7JpaeLCbBZcMb5XnnzRpEqtWraKqqorCwkLee+89Jk6cyK5du4iPj2fevHnMmTOH9evXc+DAAaqrq/nOd77DH/7wB9avX09UVBSDBg3i2WefBZwTx8aNGwGn7z4tLY0777yT2NhY9uzZc6pSOoyADnqAIfERrJiXBsD0RRlsL7SwN8afrhyXxN1XjyEpJhzBacnfffWYVht1c9VVVzF27FiSk5O54IILuPfee0lMTOSdd94hJSWFcePG8fzzz3Prrbeyd+9eJk+eTEpKCjNnzuTuu+8GYNmyZSxZsoTk5GRGjRrFSy+9BMCCBQsYM2YMo0ePZtKkSSQnJ7dKzW1NmvqY40+pqana2l888sX+EqYvysAlwsr56QyOszG3xrSWnJwczj77bH+X0WU09n6LyDpVbXQWgoBv0dcYmhDJ8nnpVFUr0xdl8OWBY/4uyRhj2kWXCXqAYZ6wr6hSpi/MYKeFvTGmC+hSQQ8wPDGS5fPSKK+sYvqiDHYVWdgbYwJblwt6gBGJUSybm05ZRRXTF2awu+i4v0syxpg20yWDHpw77Z6em8bxCqdlv+eghb0xJjB12aAHGNUnmqfnpHG0vJJpCzPIPWRhb4wJPF066AFGJ0WzbG4aJWUVTFuY0egde8YY05l1+aCHmrBPp7i0gmkLP7J57I1pDyX58NilULLf35WcNHDgQA4cONDk+s46570FvceYvtE8NSeNw8edln3eEQt7Y9rUu/fC7gx498/+riTgBeykZqcjuV8MT81JY8biTKYtzGDV/HNJjA7zd1nGdC6v3Q75m5pev/sD8L4jP2uJ8xCB/uc3vk/iGLj0nlO+7NNPP82DDz7IiRMnSEtLY+zYsezatYt7770XgMcff5x169bx0EMPceWVV7Jnzx7Kysq49dZbmT9/fov+RFXl5z//Oa+99hoiwh133MHUqVPJy8tj6tSpFBcXU1lZySOPPMJ5553HnDlzyMrKQkSYPXs2t912G9u3b+dHP/oRhYWFdOvWjUWLFjFixAieffZZfv/73+N2u4mOjua9995rUW2NsRZ9PSn9YnhyzkSKjp5g+qIM8o+U+bskYwJLnwnQLQ7EEz/igu5xkDThtJ8yJyeHVatW8cEHH5CdnY3b7SYiIoLVq1ef3GbVqlVMnToVgKVLl7Ju3TqysrJ48MEHKSoqatHrec95/8Ybb7BgwQLy8vJOznlfs65mhsuaOe83bdrErFmzAGfO+4ceeoh169Zx//3388Mf/hDg5Jz3Gzdu5OWXXz7t98SbtegbMa5/D56YPZEbl37szGc/P52EKGvZG+OTZlreAPzjNlj/OASFQdUJOPvbcPlfTvsl33zzTdatW8eECc7JorS0lPj4eAYPHkxGRgZDhw5l69atnH++84nhwQcf5IUXXgBgz549fPHFFy2abrizzXlvLfomjB/QgydmT2B/cRnTF2VQUGwte2NazbECGD8L5r7h/Hv0zC7Iqio33ngj
2dnZZGdns3XrVn73u98xdepUnnnmGZ5//nmuuuoqRIR33nmHN954g48++oiNGzcybtw4yspa9v93c3PeJyUlMWPGDJ588kl69OjBxo0bmTx5Mg8//DBz586lurr65Jz3NY+cnBzAmfP+rrvuYs+ePaSkpLT400aTBXe0x/jx47WjWPtlkZ7969f0gvvf1v3Fpf4ux5gO6bPPPvPr62/evFmHDBmi+/fvV1XVoqIi3blzpx48eFAHDRqkkydP1szMTFVVffHFF/Xyyy9XVdWcnBwNDQ3Vt99+W1VVBwwYoIWFhU2+Tvfu3VVV9fnnn9eLL75YKysrtaCgQPv37695eXm6c+dOraioUFXVv/71r3rrrbdqYWGhHjlyRFVVN2zYoMnJyaqqeu655+ozzzyjqqrV1dWanZ2tqqrbtm07+XopKSm6YcOGBnU09n4DWdpEplqLvhmpA3vy+KyJ5B0p49pFmRSWlPu7JGNMPSNHjuSuu+7i4osvZuzYsVx00UXk5eXRo0cPRo4cya5du5g4cSIAU6ZMobKykrFjx/LrX/+a9PT0Fr9eZ5vzvsvMR3+mMncUMfOxtfTtEc6K+enERoT6uyRjOgybj7592Xz0bSRtcC8emzWB3EOlXLsog6Kj1rI3xnQOFvQtkD64F0tmprL74HGuW5xpYW9MACoqKiIlJaXBo1UuivqJDa9sofPOimXpjROY9fharlucyfJ56fTsHuLvsozxO1VFRPxdxhnr1asX2dnZ/i6jSafT3W4t+tNw3pBYltw4gS8PHOO6xZkcOnbC3yUZ41dhYWEUFRWdVggZ36kqRUVFhIW17L4euxh7Bt7/opA5T2QxJC6C5fPSiOlmLXvTNVVUVJCbm9vi8eim5cLCwujbty/BwcF1lp/qYqwF/Rl69/NC5j2ZxdD4CJbNtbA3xviHjbppQ18bFsfCGeP5Yv9RZiz5mCPHK/xdkjHG1GFB3womD4/n/2aMZ2t+CTOWZnKk1MLeGNNx+BT0IjJFRLaKyDYRub2R9SNE5CMRKReRnzWy3i0iG0Tkn61RdEf09RHxPDrjHHLyirlhSSbFZRb2xpiOodmgFxE38DBwKTASmC4iI+ttdhC4Bbi/iae5Fcg5gzo7hQtGJPDIdeP5LK+YG5Z8bGFvjOkQfGnRTwS2qeoOVT0BrASu8N5AVQtUdS3QINlEpC9wGbC4Fert8L4xMoGHrz2HT/ce4calH1NiYW+M8TNfgj4J2OP1e65nma8eAH4OVJ9qIxGZLyJZIpJVWFjYgqfveC4elcjfrz2HTblO2B8tr/R3ScaYLsyXoG/sVjefxmSKyOVAgaqua25bVV2oqqmqmhoXF+fL03doU0Yn8vdrx7Ex9wgzLeyNMX7kS9DnAv28fu8L7PPx+c8Hvi0iO3G6fC4QkadbVGEnNmV0bx6aPo4New4z67GPOWZhb4zxA1+Cfi0wVEQGiUgIMA3w6YsMVfUXqtpXVQd69ntLVa8/7Wo7oW+O6c2D08axfvdhZj221sLeGNPumg16Va0EbgbW4IyceUZVN4vITSJyE4CIJIpILvBfwB0ikisiUW1ZeGdy2djePDA1haxdB5n9+FqOn7CwN8a0H5sCoR29lL2X21ZlM3FQTx6bOZHwELe/SzLGBAibAqGDuCIlib9OTeHjLw8y54m1lJ6o8ndJxpguwIK+nV2RksT/fC+Zj3YUMffJtZRVWNgbY9qWBb0fXDWuL/dfk8yH24uY92SWhb0xpk1Z0PvJd8b35b5rkvnPtgMW9saYNmVB70fXjO/Ln78zlv9sO8D3n1pnYW+MaRMW9H72vdR+3HP1GN79vJCbnl5HeaWFvTGmdVnQdwBTJ/Tn7qvH8M7WQn7w9HoLe2NMq7Kg7yCmT+zPn64aw1tbCvihhb0xphVZ0Hcg16b1564rR/PmlgJ+tGwDJypPOeGnMcb4xIK+g7k+fQB/uGIUb+Ts50fL11vYG2POmAV9BzTj3IH8/tuj+Pdn+/nx
ivVUVFnYG2NOnwV9B3XjeQP57bdGsmbzfm5ZscHC3hhz2izoO7BZ5w/i15eP5LVP87l1pYW9Meb0BPm7AHNqc74yCFXlrldyEMnmb1NTCHLb+dkY4zsL+k5g7lcHowp/fDUHlwh//V6yhb0xxmcW9J3EvEmDqVbl7te2IMBfLOyNMT6yoO9Evv+1s6hW+PO/tuAS+J/vpeB2Nfbd7cYYU8uCvpP5weSzqFblvjVbcYlw33eTLeyNMadkQd8J/ejrQ1BV7n/9cxC47xoLe2NM0yzoO6mbLxhKtcJf/v05LhHu/c5YXBb2xphGWNB3YrdcOJRqVR544wtcAvdcbWFvjGnIgr6T+8k3hlGt8OCbX+AS4U9XjbGwN8bUYUEfAG77xlBUlYfe2oYI/PFKC3tjTC0L+gAgIvzXRcOoVuXht7cjItx1xWgLe2MMYEEfMESEn108nGqFR97ZjkvgD1eMRsTC3piuzoI+gIgIP79kONWq/N+7O3CJ8Ptvj7KwN6aLs6APMCLC7VNGoAoL33PC/rffGmlhb0wXZkEfgESEX1w6gupqZfF/vkQEfnO5hb0xXZUFfYASEX512dlUKyz94EsE4deXn21hb0wXZEEfwESccK9WZekHX+IS+NVlFvbGdDUW9AFOPH30qk43jsvldOtY2BvTdVjQdwEiwu++PYpqzwVaEbh9ioW9MV2FBX0XISLcecUolNqhlz+/ZLiFvTFdgE9fUSQiU0Rkq4hsE5HbG1k/QkQ+EpFyEfmZ1/J+IvK2iOSIyGYRubU1izctIyLc+e3RXJvWn0fe2c79r29FVf1dljGmjTXbohcRN/AwcBGQC6wVkZdV9TOvzQ4CtwBX1tu9Evipqq4XkUhgnYj8u96+ph25XM70COqZLsHlmT7BWvbGBC5fum4mAttUdQeAiKwErgBOhrWqFgAFInKZ946qmgfkeX4uEZEcIMl7X9P+XC7hj1eOoboaHnprGy4RbrtomL/LMsa0EV+CPgnY4/V7LpDW0hcSkYHAOCCzifXzgfkA/fv3b+nTmxZyuYS7rx5DtSp/e/MLRJwpj40xgceXoG/sM32LOnZFJAJ4HviJqhY3to2qLgQWAqSmplrHcTtwuYQ/f2csCp4vLxFuuXCov8syxrQyX4I+F+jn9XtfYJ+vLyAiwTghv0xVV7esPNPWasK+WtXztYTO1xQaYwKHL0G/FhgqIoOAvcA04FpfnlycK3xLgBxV/ctpV2nalNsl3HdNMqpw/+ufIyL86OtD/F2WMaaVNBv0qlopIjcDawA3sFRVN4vITZ71j4pIIpAFRAHVIvITYCQwFpgBbBKRbM9T/lJVX231v8ScEbdLuP+7yVSrct+arbhE+MHks/xdljGmFfh0w5QnmF+tt+xRr5/zcbp06vsPjffxmw7I7RL+57tOy/7P/9qCS+D7X7OwN6azsztjTR1Bbhd/+Z7Tsr/7tS24RJg3abC/yzLGnAELetNAkNvFA1NTUIU/vpqDCMz9qoW9MZ2VBb1pVJDbxQPTUlCUu17JwSXC7K8M8ndZxpjTYEFvmhTsdvG3aeOort7Anf/8DJfAzPMt7I3pbHya1Mx0XcFuFw9dO45LRiXwu398xpMf7fR3ScaYFrKgN80Kdrt4aPo5XDQygd+8tJmnLOyN6VQs6I1PQoJcPHztOXzj7Hh+/dJmns7Y5e+SjDE+sqA3PgsJcvHwdedw4Yh47njxU5Zn7vZ3ScYYH1jQmxYJDXLzv9efw9eHx/HLFzax8mMLe2M6Ogt602KhQW4euX48k4fHcfvqTTyzdk/zOxlj/MaC3pyWsGA3j14/nknD4vjv1Z/wTJaFvTEdlQW9OW1hwW4WzhjPV4bE8t/Pf8Jz63L9XZIxphEW9OaMhAW7WXRDKuefFcuC5zayer2FvTEdjQW9OWM1YX/eWb346bMbeWGDhb0xHYkFvWkV4SFuFt8wgfRBvfjpMxt5KXuvv0syxnhY0JtWEx7iZsnMVCYO6slt
q7J5eaPP3zhpjGlDFvSmVXULCWLpzAmkDuzJT1Zu4B8W9sb4nQW9aXXdQoJ4bOYEUgf05Cersnnlkzx/l2RMl2ZBb9pE99AgHps1gXP6x3DLyg28tsnC3hh/saA3bcYJ+4mk9Ivhxys28K9P8/1dkjFdkgW9aVMRoUE8PmsCY/tGc/Py9by+2cLemPZmQW/aXGRYME/MnsjopGh+tHw9//5sv79LMqZLsaA37SIyLJgn50xkZJ9ofrhsHW/mWNgb014s6E27iQoL5snZEzm7dxQ/eHo9b22xsDemPVjQm3YVHR7MU7PTGJ4YyU1PreftrQX+LsmYgGdBb9pddLdgnp6TxrDECL7/1DresbA3pk1Z0Bu/qAn7IXERzH9qHe99XujvkowJWBb0xm9iuoWwbG4aZ8VFMO/JLN7/wsLemLZgQW/8qkd3J+wHxXZn7hNZfLDtgL9LMibgWNAbv+vZPYTl89IZFNudOU+s5UMLe2NalQW96RB6elr2A3p2Z/YTa/loe5G/SzImYFjQmw6jV0Qoy+al0a9HN2Y/vpaMHRb2xrQGC3rTocRGhLJ8XjpJPcKZ9dhaMi3sjTljPgW9iEwRka0isk1Ebm9k/QgR+UhEykXkZy3Z15j64iJDWT4vjT4xYcx6fC1rdx70d0nGdGrNBr2IuIGHgUuBkcB0ERlZb7ODwC3A/aexrzENxEeGsWJeOonRYcxc+jFZFvbGnDZfWvQTgW2qukNVTwArgSu8N1DVAlVdC1S0dF9jmhIfFcbKeekkRIVx49KPWbfLwt6Y0+FL0CcBe7x+z/Us84XP+4rIfBHJEpGswkK7ccY44qPCWDE/nfioMG5cupb1uw/5uyRjOh1fgl4aWaY+Pr/P+6rqQlVNVdXUuLg4H5/edAUJUU43TmxECDcu+ZgNFvbGtIgvQZ8L9PP6vS+wz8fnP5N9jTkpMdpp2feMCOGGJR+Tveewv0syptPwJejXAkNFZJCIhADTgJd9fP4z2deYOnpHh7NiXjo9uocwY0kmn+Qe9ndJxnQKzQa9qlYCNwNrgBzgGVXdLCI3ichNACKSKCK5wH8Bd4hIrohENbVvW/0xJvD1iQlnxfx0YroFc/3iTDblHvF3ScZ0eKLqa3d7+0lNTdWsrCx/l2E6sNxDx5m2MIOSskqWzU1jdFK0v0syxq9EZJ2qpja2zu6MNZ1S3x7dWDEvnYjQIK5bnMmne61lb0xTLOhNp9WvZzdWzk+ne4ib65dk8tm+Yn+XZEyHZEFvOjUn7M8lPNjNdYszyMmzsDemPgt60+n17+W07EOD3Fy3OJMt+Rb2xnizoDcBYUCv7qycn06wW7h2USZb80v8XZIxHYYFvQkYA2O7s3L+uQS5hGsXZfD5fgt7Y8CC3gSYQbHdWTE/Hbcn7L+wsDfGgt4EnrPiIlg+Lx0RYfqiTLYVHPV3Scb4lQW9CUhD4iNYMS8NgOmLMtheaGFvui4LehOwhsRHsmJeGqrK9IUZ7LCwN12UBb0JaEMTIlk+L52qamX6ogy+PHDM3yUZ0+4s6E3AG+YJ+4oqp2W/08LedDEW9KZLGJ4YyfJ5aZRXVjF9UQa7iizsTddhQW+6jBGJUSybm05ZRRXTF2awu+i4v0sypl1Y0JsuZWSfKJ6em8bxCqdlv+eghb0JfBb0pssZ1Seap+ekcbS8kmkLLexN4LOgN13S6KRols1No6SsgumLMsg9ZGFvApcFvemynLBPp7jUCfu9h0v9XZIxbcKC3nRpY/pG89ScNA4fr2D6wgz2WdibAGRBb7q85H4xPDUnjUPHTjB9UQZ5RyzsTWCxoDcGSOkXw5NzJlJ09ATTF2aQf6TM3yUZ02os6I3xGNe/B0/MnsiBo07Lfn+xhb0JDBb0xngZP6AHT8yeQEFxGdMXZlBgYW8CgAW9MfWMH9CTJ2ZPJL+4jGmLMigosbA3nZsFvTGNSB3Yk8dnTST/iNOyLywp93dJxpw2UVV/19BAamqqZmVl+bsM
Y8jcUcTMx9YSFRaEiLC/uIw+MeEsuGQ4V45L8nd5xpwkIutUNbWxddaiN+YU0gb3Ys5XBrK/pJz84jIU2Hu4lF+s3sSLG/b6uzxjfBLk7wKM6ehe2LCvwbLSiip+/twnvJGzn4SoMBKjwoiPCiUxKozE6DASosIIC3b7oVpjGrKgN6YZTd0te6Kqms37inkzp4DSiqoG66PDg+ucABKiwkiIDvP87CzrFRGK2yVt/SeYLs6C3phm9IkJb3QenKSYcN7+2WRUleKySgqKy8gvLmN/cTn7i8vIP1LG/mLn8fn+EgpLyqmud0nM7RLiIkJJiA4jITL05KeBmk8JCVHOushQ5xqBMafDgt6YZiy4ZDi/WL2pTqs9PNjNgkuGAyAiRIcHEx0ezNCEyCafp6paOXC0/kmg3HNyKGNn0TEydhRRXFbZYN9uIW7PCcDr00HNCSE6lISoMOIjwwgJsstupiELemOaUTO65r41W9l3uPS0R924XXIynMf2bXq70hNVzsmguPYTwckTwpEy1u0+xP4j5Zyoqm6wb6/uIcRHhZEY5Xw6iI90rhl4dyH16BaCy7qLuhQbXmlMJ6SqHDpecfKEUFBcRv6R8tqfPSeIA0dPNNg32C0nTwAJUaFe3URhtZ8aosPoFmLtwM7kVMMrffovKSJTgL8BbmCxqt5Tb7141n8TOA7MVNX1nnW3AXMBBTYBs1TVbjU05gyICD27h9Czewhn945qcruKqmoKSpzuov1HGl5D2JJfwrtbCzl2ouHF5MiwoHongYafEmIjQghyW3dRR9ds0IuIG3gYuAjIBdaKyMuq+pnXZpcCQz2PNOARIE1EkoBbgJGqWioizwDTgMdb9a8wxjQq2O0iKSacpJjwU253tLyyzsXjmm6imi6jHdsPUFBSTmW9q8kugdgI7xNAKAmR3qOLnH+jwu1isj/50qKfCGxT1R0AIrISuALwDvorgCfV6QfKEJEYEent9RrhIlIBdAMaDko2xvhVRGgQQ+IjGBIf0eQ21dXKgWPlFBSXOyeFkrqfEnIPHSdr10EOH69osG9YsKvuBWRPl1FCzX0Hkc41BLv3oG34EvRJwB6v33NxWu3NbZOkqlkicj+wGygFXlfV1xt7ERGZD8wH6N+/v2/VG2Pajcvl9O3HR4YxOim6ye3KKqooKC5nf0lZvU8J5ew/UsYnuYd5/UgZ5ZUNLyb36Bbc6PDShMjaG9F6dbeLyS3lS9A39o7Wv4Lb6DYi0gOntT8IOAw8KyLXq+rTDTZWXQgsBOdirA91GWM6oLBgN/17daN/r25NbqOqFJdWku89usjzKSH/iHMNISevmMKj5dQfLxLkEuIj654ATt6VHBXmjDqKDiMi1C4m1/DlncgF+nn93peG3S9NbfMN4EtVLQQQkdXAeUCDoDfGdB0iQnS3YKK7BTM8sel7Dyqrqik8Wu5cKzhSRoHnU4IzuqicbYVH+WDbAUrKG957EBEaVPeuZO8uI8+ng/jIUIK7wMVkX4J+LTBURAYBe3Eupl5bb5uXgZs9/fdpwBFVzROR3UC6iHTD6bq5EGi7cZMl+fDcLLjmcYhMaLOXMca0jyC3i97R4fSODq/blKznWHnlyfsN6t+DkH+kjI+/PEhBSRkVVXU/HohAr+6htTeinfyUEOq5H8F5xHQL7tQXk5sNelWtFJGbgTU4wyuXqupmEbnJs/5R4FWcoZXbcIZXzvKsyxSR54D1QCWwAU/3TJt4917YnQHv/hku/0ubvYwxpmPpHhrE4LgIBsed+mLyoeMnvE4C5XWuIew7Ukb2nsMUHWt470FIkOvkySDe+/qB1/DTxOjTn8juxQ17z/iGvFMJjBum7oqHyka+GCIoFO4oaL3CjDEBr7zSuZhc4HW9oO6nBOcE0dhEdlFhQQ3nK6o3j1FsvYnsXtywt9EpNu6+ekyLwv5UN0wFRtCX5MO/fgmbn69dFhIJgydD/zRIHAu9x0J4j1av1RjT9agqJeWVde41
8O4qqjkhFB4tp6revQcnJ7LzfCL4YNuBRm9YS4oJ54PbL/C5pjO+M7bDi0yEsChAwBUE1ZVOqO/bAFv+UbtdzAAn8HsnQ+8U5wRgffnGmBYSEaLCgokKa34iu6KjzonAufeg3HNycD4h7Cw61mjIQ9PTY5+OwAh6gGOFkDobUmdB1mNwdD9MWwbHDkDeRsj/xPk3byPkeIV/RKIn+D0ngMSxENPfuUpjjDFnwO0S4j39+k1NZHf+PW+x93ApcRzi7yEPcfOJWygkhj7N3M3cEoHRddNSZcWQv6nuCaBwK6jnzBoW4xX+Kc7PPc8CV+APwzLGtK+aPvpf6iKuc7/JsqoL+ZPMsz76NlFRCvs/g7zs2vDfvxmqPFfgg7tD4pi6rf+4EeAObt86jTGBo7Ic7u5bmzPeWjiYJPD76FtDcDj0He88alRVOC39mi6f/E9gw9Pw8TFnvTsE4kd6wt/zSBjlPJcxxtQ4cQwOfO7kSeEWKPzc+ffQl6D1poIICoezL4eL/9hqL29BfyruYEgc7TzGXecsq66Gg9trwz9vI+S8DOufcNaLG+KGe0b6eFr/iWMgrOm5QYwxAaL0sCfQt3hC3fM4srt2G1cQ9Bri5Mro7zh5kfMPJ0fcIVBVDqFRrTpQxIK+pVwuiB3qPMZc4yxThSN7PMHv6fb58l34ZGXtfj0He4W/59E91j9/gzHm9Kk6gzwObG0Y6Efza7cLCnNyon8axN7gBHrcCOg5qGGX7+YXYPysuoNJWpH10belkv11R/vkbYTDu2rXRyXVjvSpCf+oPjbix5iOQBWK93kCvV6XS+nB2u1CIiFumBPiccMhdrjzb0x/cLXftMvWR+8vkQkQeREMvah2Wemh2hE/Na3/ra9xckLQbr1qQ7/mBNBjkI34MaatVFc7DbDGulxOlNRuF97DCfOR33b+jfWEeydonFnQt7fwHjBokvOoceKYM8Inb6Mz6ifvE/jw71Dt+QKHkMi64/x7JzsHmdv+8xnjs6oKOPilE+berfQD26DS6+akiESnhZ4y3auFPsLpau3ggd4US4qOIKQ79JvoPGpUlkNBjlfXzydO313NARkU5ozw8W79x4+E4DD//A3GdBQVZVC0zRPiXq30ou21jSeA6P5OoA/6mlegDwvIqVIs6DuqoFDok+I8alRXwYEv6vb7b3oespY6611BEHe21zQPyZAwGkKbntHPmE6r/GjtkEXvFvqhnbVDFsXldH3GDYfhl9Z2ucQO61L/X9jF2M5O1Tmw60/zcKzQs4E4Q7nqT/PQrac/qzbGd6WHai+CerfQj3h9e6kr2DnO618U7TWky3zKtYuxgUzEGa7VcxCMutJZpurM6Ol9o9eeTPj0udr9ovvXbfn3TnYmhzPGH2qGLBZuaRjo3kMNg8I9QxbPhbgbPS304Y0PWTQnWdAHIhGI6u08hk+pXX78YN3wz9sIW/5Zu757vFfwe04CMQM67QUo0wHVDFmsCXHvLpfSQ7XbhUQ6rfIhF9VtpUf3txFop8GCvivp1hPO+rrzqFFeAvmf1j0BbH/La4K3aK9x/inOCaDXkHYdH2w6oeoqZ8hi4ef1Qv3zekMWe3qGLF7puaHIM8Ilsrc1MFqRBX1XFxoJA851HjUqyqBgc+04/7yN8PEi59ZsgOBuzkVe79Z/3NkQFOKfv8H4T1UFHNzhNfbcM3TxwBdQWVa7XUSiE+Ip13q10EfY3eHtxILeNBQcBknjnUeNqgqn39T7Rq+NK2HtIme9KxgSRtZt/SeMgpBufvkTTCurKIOiL+oF+ufOMMbqytrtYvo7feY1QxZrRrmEx/itdGOjbsyZqK52Zt+rucmrpvVfc3u4uJz/yb1v9EocY//Td2TlR2u7WLy7XOoPWew5uPZW/7gRTis9dphzT4jxC5uP3rQfVSje6zW/j+cEULKvdpseA71u9PL8GxHnt5K7pNJDdW/1r2mh1x+yGDu09lb/mj70XkOc+zxMh2LDK037EYHo
vs5jxGW1y48WQv7GuieAz16qXR/Zu+EcP9F97YLcmVB17qeoM3+LJ9DrD1mMG+YZsjizNtR7DLJpNgKE/Vc07SMiDoZ8w3nUKDviNcGbJ/y/eL22iyC8Z92x/onJTpeBDa+rq+ZTlPfsijVdLt5DFkOjvIYsenW52JDFgGdBb/wnLBoGfsV51Dhx3Jngzbv1n/FI7VethUR4WvxeJ4DY4V2j5XlyyGIjXS4njtZu162X15BFry4XG7LYZXWB/ztMpxLSDfpNcB41Kk84geZ9o9f6J6HiuLPeHeo1wZvnBBA/qvPe+n5yyGK9KXOL6g1ZjOztBPi46+v2o9uQRVOPBb3p+IJCPAE+tnZZdZUzG2HN1M75n8Dm1bDuMWe9uJ3g8x7rnzjGuW+go6go9cyyWO+big5ubzhkMW4EnDW5dsrcuGH29ZTGZzbqxgQOVTi8u+5dvvuy4ViBZwOBXmfV+z7fZOjeq23rKi9p4ouhd3LyC2fE7czXUv9bimKH2pBF4xMbdWO6BhHoMcB5jPx27fKSfK9x/tmwN8tp/deI7lfv+3zHNt6fXZIPz82Cax5v/Iubjx9s/FuKinNrt3GHOMMT+6RA8rTaLpdeZ9mQRdNmLOhN4ItMdB7DLq5ddvygp7/f60avra9ysoXdPa7h9/l++CDszoA3fuvcyl+/y+XkJwecaSJih8HA8+t+S1GPgV3jwrHpUKzrxpga5Udh/6d1w78wp25/eX2h0Z5RLV7zt8QOcz4l2JBF046s68YYX4RGQP9051Gjshx2vg9v/8np79cq547RAefBxXc5F3htyKLp4KzJYcypBIU6N3klJgPqfFevVnm+tWushbzpFKxFb4wvjhXA+FmQOsv5knbvKQSM6eB8CnoRmQL8DXADi1X1nnrrxbP+m8BxYKaqrvesiwEWA6NxrnTNVtWPWusPMKZdTFtW+/Plf/FfHcachma7bkTEDTwMXAqMBKaLyMh6m10KDPU85gOPeK37G/AvVR0BJAM5rVC3McYYH/nSRz8R2KaqO1T1BLASuKLeNlcAT6ojA4gRkd4iEgVMApYAqOoJVT3ceuUbY4xpji9BnwR4TVJNrmeZL9sMBgqBx0Rkg4gsFhG7zc8YY9qRL0Hf2LCC+oPvm9omCDgHeERVxwHHgNsbfRGR+SKSJSJZhYWFPpRljDHGF74EfS7Qz+v3vsA+H7fJBXJVNdOz/Dmc4G9AVReqaqqqpsbF2bcNGWNMa/El6NcCQ0VkkIiEANOAl+tt8zJwgzjSgSOqmqeq+cAeERnu2e5C4LPWKt4YY0zzfJoCQUS+CTyAM7xyqar+UURuAlDVRz3DK/8OTMEZXjlLVbM8+6bgDK8MAXZ41h1q8CJ1X68Q2HWaf1MscOA0921LVlfLWF0tY3W1TCDWNUBVG+0O6ZBz3ZwJEclqar4Hf7K6Wsbqahmrq2W6Wl02BYIxxgQ4C3pjjAlwgRj0C/1dQBOsrpaxulrG6mqZLlVXwPXRG2OMqSsQW/TGGGO8WNAbY0yA6zRBLyJTRGSriGwTkQbTKHhu1nrQs/4TETnH133buK7rPPV8IiIfikiy17qdIrJJRLJFpFW/O9GHuiaLyBHPa2eLyG983beN61rgVdOnIlIlIj0969ry/VoqIgUi8mkT6/11fDVXl7+Or+bq8tfx1Vxd/jq++onI2yKSIyKbReTWRrZpu2NMVTv8A+dGre04k6SFABuBkfW2+SbwGs68O+lApq/7tnFd5wE9PD9fWlOX5/edQKyf3q/JwD9PZ9+2rKve9t8C3mrr98vz3JNwpuf4tIn17X58+VhXux9fPtbV7seXL3X58fjqDZzj+TkS+Lw9M6yztOhPe6pkH/dts7pU9UOtvRM4A2ceoLZ2Jn+zX9+veqYDK1rptU9JVd8DDp5iE38cX83W5afjy5f3qyl+fb/qac/jK089X8akqiU438tRfxbgNjvGOkvQn8lUyb7s25Z1eZuDc8auocDrIrJOROa3
Uk0tqetcEdkoIq+JyKgW7tuWdSEi3XCm1Hjea3FbvV++8Mfx1VLtdXz5qr2PL5/58/gSkYHAOCCz3qo2O8Y6y3fGnslUyb7se7p8fm4R+TrO/4hf8Vp8vqruE5F44N8issXTImmPutbjzI1xVJy5jF7E+YawDvF+4Xys/kBVvVtnbfV++cIfx5fP2vn48oU/jq+W8MvxJSIROCeXn6hqcf3VjezSKsdYZ2nRn+lUyc3t25Z1ISJjcSZ2u0JVi2qWq+o+z78FwAs4H9HapS5VLVbVo56fXwWCRSTWl33bsi4v06j3sboN3y9f+OP48okfjq9m+en4aol2P75EJBgn5Jep6upGNmm7Y6wtLjy09gPnk8cOYBC1FyNG1dvmMupeyPjY133buK7+wDbgvHrLuwORXj9/CExpx7oSqb1hbiKw2/Pe+fX98mwXjdPP2r093i+v1xhI0xcX2/348rGudj++fKyr3Y8vX+ry1/Hl+dufBB44xTZtdox1iq4bVa0UkZuBNdROlbxZvKZKBl7FuWq9Dc9Uyafatx3r+g3QC/hfEQGoVGd2ugTgBc+yIGC5qv6rHeu6BviBiFQCpcA0dY4qf79fAFcBr6vqMa/d2+z9AhCRFTgjRWJFJBf4LRDsVVe7H18+1tXux5ePdbX78eVjXeCH4ws4H5gBbBKRbM+yX+KcqNv8GLMpEIwxJsB1lj56Y4wxp8mC3hhjApwFvTHGBDgLemOMCXAW9MYYE+As6I0xJsBZ0BtjTID7fxtypFWEvjHWAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 画图\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.plot(range(len(losses)), losses, marker='o',label='losses')\n",
    "plt.plot(range(len(eval_losses)), eval_losses, marker='*',label='eval_losses')\n",
    "plt.legend() #图例\n",
    "plt.title(\"loss\") #标题\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "55d33dfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 保存模型 .save('模型.h5')\n",
    "\n",
    "torch.save(model_2,'pytorch_number_model_2.h5')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d1ff6c5",
   "metadata": {},
   "source": [
    "# 5.模型预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b358f7b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn, optim\n",
    "import torch.nn.functional as F\n",
    "from torchkeras import summary,Model\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "\n",
    "# 加载模型\n",
    "\n",
    "#方法一：加载模型\n",
    "model = torch.load('pytorch_number_model.h5')\n",
    "#方法二：加载模型（加载方法二保存的 pytorch_number_model_2.h5）\n",
    "model_2 = torch.load('pytorch_number_model_2.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "9238dfb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "图片原始格式: (4, 28, 28)\n",
      "标签： [1 0 4 1]\n",
      "图片转换成tensor格式后： torch.Size([4, 28, 28])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<function matplotlib.pyplot.show(close=None, block=None)>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAATsAAAD7CAYAAAAVQzPHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAW6ElEQVR4nO3de3RV1Z0H8O+PJCQCWok8jCEKrUHBR1HDo2pblaJobUFbH3TGMh1abMGqq85S2tqHtrXYzmJ1fBsLE6YqasUW2mo7mIGi1fIUEYgQVNBgJFB5RAohCb/5I8dz7r7mJDf3nse9d38/a2Xdve8+yf5pfvxyzr7nIaoKIqJ81yvuAIiIosBiR0RWYLEjIiuw2BGRFVjsiMgKLHZEZIWMip2ITBSRzSKyVURmBRUUUdyY2/lH0j3PTkQKAGwBMAFAA4BVAKao6qbgwiOKHnM7PxVm8L1jAGxV1TcBQESeADAJgG9C9JZiLUHfDKakoDRjz25VHRh3HFmKuZ2jDuEADmuLdDaWSbErB/BOQr8BwNiuvqEEfTFWxmcwJQXleX16e9wxZDHmdo5aobW+Y5kUu86q50eOiUVkOoDpAFCCPhlMRxQZ5nYeyuQDigYAFQn9IQDeTd5IVatVtUpVq4pQnMF0RJFhbuehTIrdKgCVIjJMRHoDuBbA4mDCIooVczsPpX0Yq6ptInIDgL8AKAAwT1U3BhYZUUyY2/kpkzU7qOqzAJ4NKBairMHczj+8goKIrMBiR0RWYLEjIiuw2BGRFVjsiMgKLHZEZAUWOyKyQkbn2dFHyTmnue0/Lf6NMXbGQze47YqfvBRZTEQfKjj2Y0Z/830fd9uvX/hrY+z2pnOM/mv/Mtxtt2/aEkJ04eKeHRFZgcWOiKzAYkdEVuCaXcCaRh/jttvQboz1eTe9W+ATBeXIsCFG/7ULHnbbrUnp+dNBa4z+J684121XcM2OiCg7sdgRkRV4GBuwPWd6h64NbS3G2HFzX446HCIUVniHrsOqt8YYSby4Z0dEVmCxIyIrsNgRkRW4ZpchPW+U0X/h8jlu+7PLv22MnYxXogiJLPf2D881+udM9J7t/YuyF9L+uf3O3eW23/mBOceA9W1u+6hFK9OeI0zcsyMiK7DYEZEVeBibofdHHmX0ywq8J8OXP10UdThEWH/9vUa/Vdt9tuyZZZ98zOt80hz73YEytz2vebIxVvh/5pUYceGeHRFZgcWOiKzAYkdEVuCaXYbGzzAvAfv9gWPddr9lm42xYFZOiD6qaJm3ZlYkBYH8zFcOHzH621oHuu0r+r5vjF3dr8lr/6baGLu83LzjcVy63bMTkXki0iQiGxLeKxWRJSJS77z2DzdMouAxt+2SymFsDYCJSe/NAlCrqpUAap0+Ua6pAXPbGt0exqrqchEZmvT2JAAXOO35AJYBuC3IwLJVwWmnGP27Bi0w+nP3e3eYaN+7L5KYKD25nNsHJ48x+l8r+63bTj7VJNVTT06v/abRH1hbbPSL93k/57sXmPtJr111j+/Pbfiud7XFkJ/H96CpdD+gGKyqjQDgvA4KLiSiWDG381ToH1CIyHQA0wGgBH262ZoodzC3c0u6e3Y7RaQMAJzXJr8NVbVaVatUtaoIxX6bEWUL5naeSnfPbjGAqQBmO6+LAosoy+2YcFyX42uaT0roHQw3GApD1uZ24nrxT+eYp3dU9T6cuKXvz0i8rAsAbl/6Jbc94tbXjbH2/ft9f84p9cON/sovlrjtMcWHjLHnvvULt31xya3G2NC7vEvJtMW8s3fQUjn1ZAGAlwGcIiINIjINHYkwQUTqAUxw+kQ5hbltl1Q+jZ3iMzQ+4FiIIsXctguvoOih/SNbuxxfd98ot30s+IAdCs6R3t4/V/OwtWv/vt07lbD5GvMuPcMbvBtt9uQKn/ak58bOqPFOW1l9/a+MsbICb86108yxLz0z1W3rq3U9iKDneG0sEVmBxY6IrMBi
R0RW4JpdClouHe22F11s3gX2zt3mHR1KF6532+Y9I4ii8b2dVUZ//9e906XaG+pDmXPowt1u+weTxxljs49fFcqcPcU9OyKyAosdEVmBh7EpaLjI+990Zu8SY2zqtjOM/qAD5lnoRGHo6gad68/WpHfCOXQ1iLjNwl7mAk5Xsb57h9c+fnLQQZm4Z0dEVmCxIyIrsNgRkRW4ZpeCgad7d/lpV3M9onARH1FA0dj8Le+eeUE9+Doo2670Tm95euBKY6xVCxLaZtwn/Mhrh32qFvfsiMgKLHZEZAUWOyKyAtfsOlE47CSj/5+neE9uemRfhTFWOo+3caJo3P7pP8Q6f2GF9+S85nNOMMYe+toDKf2MlS3meapyuC3zwFLEPTsisgKLHRFZgYexnai/3txFH5fw4KhvrL3QGKvAhihCIordpjuOd9sbL74v5e9b+MEAt/3gf1xljJXUrUzePDTcsyMiK7DYEZEVWOyIyApcs+vEkYpDvmMH95b4jhHlk6Jl5gO1f162MK2fU7PjXLdd8ofo1uiScc+OiKzAYkdEVuBhbCceGPuo71j5c/53XSUKU4F49wXp6u6/+78yznfsjjvnGv0Lj/Jfskmew7xjSer/DvSiHSlvGybu2RGRFbotdiJSISJLRaRORDaKyE3O+6UiskRE6p1X3tiNcgpz2y6p7Nm1AbhFVUcAGAdgpoiMBDALQK2qVgKodfpEuYS5bZFu1+xUtRFAo9NuFpE6AOUAJgG4wNlsPoBlAG4LJcoIHPrCGLd9fknyx+Nc2sxHuZbbs5/8stu+etqvfLdb/sv7jX5XdzVuTX4QWRdSvTvy6bXfNPqVWJv6JCHq0ZqdiAwFcBaAFQAGO8nyYdIMCjw6oogwt/NfysVORPoBWAjgZlXd34Pvmy4iq0VkdSta0omRKFTMbTukdHwmIkXoSIbHVPUZ5+2dIlKmqo0iUgagqbPvVdVqANUAcIyU9mCnOVpvf9ELrVjM/y137vYehN1v0RpjLGv/gygluZTbH39yt9te+a/mlTxjiv1PIQlK4o03q9/7rDG2Z4Z3R5RT39pqjGXLo4FS+TRWAMwFUKeqcxKGFgOY6rSnAlgUfHhE4WFu2yWVPbvzAFwH4DURWee89z0AswE8JSLTALwN4KrOv50oazG3LZLKp7EvAhCf4fHBhkMUHea2Xaw9p6LgmGOM/m3nPeu77ePPfcZtf7yND9iheLRv2uK2f/idrxtj73zBu5Rsy6UPhzL/jHneKSUVP3spaXRPKHMGiZeLEZEVWOyIyArWHsYeaTHPi9r0T+8hO5/bUWWMVd610W1ny8foZLejFplX+QxP+Lz4M1NmGmNF/7bTbf/5tCeNsYs3XOu2j9SY505r0mrm0HW73HYu/jvgnh0RWYHFjoiswGJHRFawds1Ok9bsNics0/XGdmMsF9cnyF7HLPi7+cYCr3kFxhhDffFmQu9NdCXX/x1wz46IrMBiR0RWYLEjIiuw2BGRFVjsiMgKLHZEZAUWOyKyAosdEVmBxY6IrMBiR0RWYLEjIiuw2BGRFVjsiMgKohrdY55FZBeA7QAGANjdzeZRsTWWk1R1YERz5T0ntw8ge3IJsDO3ffM60mLnTiqyWlWrut8yfIyFgpJtv79siicbYuFhLBFZgcWOiKwQV7GrjmnezjAWCkq2/f6yKZ7YY4llzY6IKGo8jCUiK0Ra7ERkoohsFpGtIjIryrmd+eeJSJOIbEh4r1RElohIvfPaP6JYKkRkqYjUichGEbkpzngoM3HmNvM6NZEVOxEpAHA/gEsBjAQwRURGRjW/owbAxKT3ZgGoVdVKALVOPwptAG5R1REAxgGY6fz/iCseSlMW5HYNmNfdinLPbgyArar6pqoeBvAEgEkRzg9VXQ7g/aS3JwGY77TnA5gcUSyNqrrWaTcDqANQHlc8lJFYc5t5nZooi105gHcS+g3Oe3EbrKqNQMcvCsCgqAMQkaEAzgKwIhvioR7LxtyOPY+y
La+jLHbSyXvWfxQsIv0ALARws6rujzseSgtzO0k25nWUxa4BQEVCfwiAdyOc389OESkDAOe1KaqJRaQIHQnxmKo+E3c8lLZszG3mdZIoi90qAJUiMkxEegO4FsDiCOf3sxjAVKc9FcCiKCYVEQEwF0Cdqs6JOx7KSDbmNvM6mapG9gXgMgBbALwB4PtRzu3MvwBAI4BWdPw1ngbgOHR8OlTvvJZGFMv56DjUWQ9gnfN1WVzx8Cvj32dsuc28Tu2LV1AQkRV4BQURWYHFjoiskFGxi/vyL6KwMLfzT9prds4lMlsATEDHougqAFNUdVNw4RFFj7mdnwoz+F73EhkAEJEPL5HxTYjeUqwl6JvBlBSUZuzZrXwGhR/mdo46hAM4rC2dneSdUbHr7BKZsV19Qwn6YqyMz2BKCsrz+vT2uGPIYsztHLVCa33HMil2KV0iIyLTAUwHgBL0yWA6osgwt/NQJh9QpHSJjKpWq2qVqlYVoTiD6Ygiw9zOQ5kUu2y8RIYoCMztPJT2YayqtonIDQD+AqAAwDxV3RhYZEQxYW7np0zW7KCqzwJ4NqBYiLIGczv/8AoKIrICix0RWYHFjoiswGJHRFZgsSMiK7DYEZEVWOyIyAosdkRkBRY7IrICix0RWYHFjoiskNG1sbms/cKzjf4N1U+57QcrTw59/uZrxhn9Y9ftdtvtm7eGPj9RT+396qfc9orZDxpjI++f4bZPvHulMaZtbeEGliLu2RGRFVjsiMgK1h7Gbr/EvLNsacEHkc7/3ucPG/3W67y/O6WXRxoKUacKy08w+j/54a99t9008wG3fek9nzbGtLk52MDSxD07IrICix0RWYHFjoisYNWanRT1dtsXXbQuvkAAHP1KidG/etpf3fbSY4cYY+1790USE1GipktOMvoX92n13fbs1de47YEfbAktpkxwz46IrMBiR0RWsOowtvkK76qJe8rvNcZG/P4Gt12JFaHH0tLffMD8jf1fd9vLjh5hbszDWIpArz59jP4lN76Y8vcWP9Hf66j6bxgj7tkRkRVY7IjICix2RGSFvF6z0/NGGf377/4vt/3ofvNj9VNv9z4ubw81qg6funhDBLMQpa7lXHOt+KeD5vpu+88j5uWOxzz+91BiClK3e3YiMk9EmkRkQ8J7pSKyRETqndf+Xf0MomzE3LZLKoexNQAmJr03C0CtqlYCqHX6RLmmBsxta3R7GKuqy0VkaNLbkwBc4LTnA1gG4LYgAwvCnu/+0+gPKfRuIvidb3/eGCvasyb0eArLjnfb/33in42xVuXyadRyObfD8NaVBSlv++X6yUnvvBtoLGFI91/YYFVtBADndVBwIRHFirmdp0L/gEJEpgOYDgAl6NPN1kS5g7mdW9Lds9spImUA4Lw2+W2oqtWqWqWqVUUo9tuMKFswt/NUunt2iwFMBTDbeV0UWEQZ+sc3vIeC/PaMXxpj/7PvTLdd9Hz4a3TJNt1Z4bZb1TzBZeq2z7nt9qZdkcVEH5G1uR22z49+tcvxfUcOuu3WHw82xnrlw5qdiCwA8DKAU0SkQUSmoSMRJohIPYAJTp8opzC37ZLKp7FTfIbGBxwLUaSY23bJuysoek32nr96QqG5jjL3ce+UqiF4KfRYCk47xeg/Ov5ht92i5o0Q354z3G33bQn/ritEANBy2Wi3fV/5I11u25Dw+Ndef30lrJBCw5O7iMgKLHZEZAUWOyKyQs6v2RUMHGj0bx/+J99th9wV/jpdotdnHGv0q4q9003u3zPSGOu7kOt0FL2do4tS3vYLf7zZbUdxN++gcc+OiKzAYkdEVsj5w1jpYz5/9ZI+3sNpxqz6qjF2POoiielDA4a+7zv22FtV5rbIzmdtUn7rfdYe37G6w+Zdg069xzutK4ob3AaNe3ZEZAUWOyKyAosdEVkh59fsjry/1+j/ZJf3IOyvfGK1Mba87BNuu63xvVDiKTzJu7PJ30Y9kTTq/W05+PcBSWNcs6PwHbp8
jNFfPfrBhJ55p+LNreZ9S9u3vBFWWJHgnh0RWYHFjoiswGJHRFbI/TW75maj/787TnXbL4x63Bhr/OPHvLGHP4V07B2pRr/f0H1Gf9wJ27zYcMT354j6DhGF5uAAc12uSPyfKHbrmiuN/jCsDyWmqHDPjoiswGJHRFbI+cPYZP3v8C4f++yPzbtu/+70Grd9949eTuvnr24xd/vbk/5eVPU+nNAT359z4r2vGX3/A16i4LRM3us7lnx52JBfp35HlFzAPTsisgKLHRFZgcWOiKyQd2t2WOmthX3sMnPougtudNt7K9N7gvtxj3S91rfjmdPc9pqxNb7bJZ8yQxSWguHeZZKrRz+aPOq2nvvgdGMkjgfJh4l7dkRkBRY7IrJC/h3GdqFg2Vq3fdyycOY4uO1orzPWfzs9b5TRl7+tCyUeop0Xencv6eqKifuWTjD6ufhQna50u2cnIhUislRE6kRko4jc5LxfKiJLRKTeee0ffrhEwWFu2yWVw9g2ALeo6ggA4wDMFJGRAGYBqFXVSgC1Tp8olzC3LdJtsVPVRlVd67SbAdQBKAcwCcB8Z7P5ACaHFCNRKJjbdunRmp2IDAVwFoAVAAaraiPQkTQiMqir77VGwhVivbr4W8I1uuySz7l9qNT/ssU1Ld7ljSPubjDG2kKLKB4pfxorIv0ALARws6ru78H3TReR1SKyuhUt6cRIFCrmth1SKnYiUoSOZHhMVZ9x3t4pImXOeBmAps6+V1WrVbVKVauKkN6JvERhYW7bo9vDWBERAHMB1KnqnIShxQCmApjtvC4KJcJck3BTzq5u3knxsyW3B120w3ds8f6z3Hb7rt2+2+WDVNbszgNwHYDXRGSd89730JEIT4nINABvA7gqlAiJwsPctki3xU5VX4T/jdnGBxsOUXSY23bh5WJEZAWrLheLwpES/3W6Xe38xI7CJ8XmhyWTTnjVd9t/HO7ntrUlv/OTe3ZEZAUWOyKyAg9jA/boxIfcdt1h85B2Ss2tbvtEvBRZTGSZ9najW113vtu++dxtxtiyd0522+XYGGpYceOeHRFZgcWOiKzAYkdEVuCaXcDufOuLbvvAA+XG2IkLuU5H4dM2834lQ2cdcNsjfn6dMSbrjoYtuGdHRFZgsSMiK/AwNmjjvRsg9kVDFxsSRaN961tu+0SLb2nAPTsisgKLHRFZgcWOiKzAYkdEVmCxIyIrsNgRkRVY7IjICix2RGQFFjsisgKLHRFZQVS1+62CmkxkF4DtAAYAyJYn8toay0mqOjCiufKek9sHkD25BNiZ2755HWmxcycVWa2qVZFP3AnGQkHJtt9fNsWTDbHwMJaIrMBiR0RWiKvYVcc0b2cYCwUl235/2RRP7LHEsmZHRBQ1HsYSkRUiLXYiMlFENovIVhGZFeXczvzzRKRJRDYkvFcqIktEpN557R9RLBUislRE6kRko4jcFGc8lJk4c5t5nZrIip2IFAC4H8ClAEYCmCIiI6Oa31EDYGLSe7MA1KpqJYBapx+FNgC3qOoIAOMAzHT+f8QVD6UpC3K7BszrbkW5ZzcGwFZVfVNVDwN4AsCkCOeHqi4H8H7S25MAzHfa8wFMjiiWRlVd67SbAdQBKI8rHspIrLnNvE5NlMWuHMA7Cf0G5724DVbVRqDjFwVgUNQBiMhQAGcBWJEN8VCPZWNux55H2ZbXURY76eQ96z8KFpF+ABYCuFlV98cdD6WFuZ0kG/M6ymLXAKAioT8EwLsRzu9np4iUAYDz2hTVxCJShI6EeExVn4k7HkpbNuY28zpJlMVuFYBKERkmIr0BXAtgcYTz+1kMYKrTngpgURSTiogAmAugTlXnxB0PZSQbc5t5nUxVI/sCcBmALQDeAPD9KOd25l8AoBFAKzr+Gk8DcBw6Ph2qd15LI4rlfHQc6qwHsM75uiyuePiV8e8zttxmXqf2xSsoiMgKvIKCiKzAYkdEVmCxIyIrsNgRkRVY7IjICix2RGQFFjsisgKL
HRFZ4f8BaMAYaCfRwFUAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 4 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 手动导入预测数据\n",
    "from tensorflow.keras.datasets import mnist\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "\n",
    "(train_images, train_labels), (test_images, test_labels) = mnist.load_data() \n",
    "\n",
    "# 取4张测试图片（索引2~5）\n",
    "test_images = test_images[2:6]\n",
    "test_labels = test_labels[2:6]\n",
    "print('图片原始格式:', test_images.shape)\n",
    "print('标签：', test_labels)\n",
    "\n",
    "# 图片格式转换成tensor格式\n",
    "test_images = torch.tensor(test_images)\n",
    "print('图片转换成tensor格式后：', test_images.shape)\n",
    "\n",
    "for i in range(4):\n",
    "    plt.subplot(2,2,i+1)\n",
    "    plt.imshow(test_images[i])\n",
    "plt.show()\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "3a0208ef",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 28, 28])\n",
      "tensor([[[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         ...,\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.]],\n",
      "\n",
      "        [[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         ...,\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.]],\n",
      "\n",
      "        [[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         ...,\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.]],\n",
      "\n",
      "        [[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         ...,\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "         [0., 0., 0.,  ..., 0., 0., 0.]]])\n"
     ]
    }
   ],
   "source": [
    "from torch.autograd import Variable\n",
    "from torch.utils.data import Dataset, TensorDataset, DataLoader\n",
    "\n",
    "# 图片类型转换成FloatTensor并除以255归一化（与训练数据预处理保持一致）\n",
    "test_images = test_images.type(torch.FloatTensor)/255\n",
    "\n",
    "print(test_images.shape)\n",
    "print(test_images)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "72140cf8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "添加通道维度 torch.Size([4, 1, 28, 28])\n"
     ]
    }
   ],
   "source": [
    "# 为图片格式添加通道维\n",
    "pic = test_images[0:4]\n",
    "\n",
    "pic = pic.reshape(4,1,28,28)\n",
    "print('添加通道维度', pic.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6d6ee79",
   "metadata": {},
   "source": [
    "#### 方法一模型预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "3e81d8c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ -7.8818,   8.8549,  -5.7712,  -5.2882,   0.1410,  -4.1423,  -4.7644,\n",
      "          -3.7146,  -2.1416,  -9.7490],\n",
      "        [ 11.9379, -18.4716,  -4.9877, -11.8122,  -8.6285,  -7.6033,   1.2691,\n",
      "          -8.3958,  -2.0147,  -2.5103],\n",
      "        [ -5.3669,  -9.1445, -11.3962,  -6.2549,  10.2492,  -8.9682,  -7.9669,\n",
      "          -7.9985,  -2.5321,  -1.9513],\n",
      "        [ -8.7086,   9.1721,  -5.6316,  -7.1824,   0.9074, -11.6737, -11.5757,\n",
      "          -0.2226,  -1.1573,  -7.1985]], grad_fn=<AddmmBackward>)\n",
      "------\n",
      "torch.return_types.max(\n",
      "values=tensor([ 8.8549, 11.9379, 10.2492,  9.1721], grad_fn=<MaxBackward0>),\n",
      "indices=tensor([1, 0, 4, 1]))\n"
     ]
    }
   ],
   "source": [
    "# 预测模式\n",
    "\n",
    "output = model(pic)\n",
    "print(output)\n",
    "\n",
    "print('------')\n",
    "# 按行找到一行内值最大的列号\n",
    "prediction = torch.max(output, dim=1)\n",
    "print(prediction)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d50953cf",
   "metadata": {},
   "source": [
    "#### 方法二模型预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "918258ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ -7.8818,   8.8549,  -5.7712,  -5.2882,   0.1410,  -4.1423,  -4.7644,\n",
      "          -3.7146,  -2.1416,  -9.7490],\n",
      "        [ 11.9379, -18.4716,  -4.9877, -11.8122,  -8.6285,  -7.6033,   1.2691,\n",
      "          -8.3958,  -2.0147,  -2.5103],\n",
      "        [ -5.3669,  -9.1445, -11.3962,  -6.2549,  10.2492,  -8.9682,  -7.9669,\n",
      "          -7.9985,  -2.5321,  -1.9513],\n",
      "        [ -8.7086,   9.1721,  -5.6316,  -7.1824,   0.9074, -11.6737, -11.5757,\n",
      "          -0.2226,  -1.1573,  -7.1985]], grad_fn=<AddmmBackward>)\n",
      "------\n",
      "torch.return_types.max(\n",
      "values=tensor([ 8.8549, 11.9379, 10.2492,  9.1721], grad_fn=<MaxBackward0>),\n",
      "indices=tensor([1, 0, 4, 1]))\n"
     ]
    }
   ],
   "source": [
    "output = model_2(pic)\n",
    "print(output)\n",
    "\n",
    "print('------')\n",
    "# 按行找到一行内值最大的列号\n",
    "prediction = torch.max(output, dim=1)\n",
    "print(prediction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f42f8420",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
